{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This notebook allows you to do real-time (\"streaming\") speech recognition using audio recorded from your microphone. This notebook shows how to use a NeMo chunk-aware FastConformer model with caching enabled.\n",
    "\n",
    "## Installation\n",
    "\n",
    "The notebook requires PyAudio library, which is used to capture an audio stream from your machine. This means that you need to run this notebook locally. This notebook will not be able to record your audio if you run it in Google Colab or in a Docker container.\n",
    "\n",
    "For Ubuntu, please run the following commands to install it:\n",
    "\n",
    "```\n",
    "sudo apt install python3-pyaudio\n",
    "pip install pyaudio\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Install dependencies\n",
    "!pip install wget\n",
    "!apt-get install sox libsndfile1 ffmpeg portaudio19-dev\n",
    "!pip install text-unidecode\n",
    "!pip install pyaudio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ## Uncomment this cell to install NeMo if it has not been installed\n",
    "# BRANCH = 'main'\n",
    "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import dependencies\n",
    "import copy\n",
    "import time\n",
    "import pyaudio as pa\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from omegaconf import OmegaConf, open_dict\n",
    "\n",
    "import nemo.collections.asr as nemo_asr\n",
    "from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE\n",
    "from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer\n",
    "from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis\n",
    "\n",
    "# specify sample rate we will use for recording audio\n",
    "SAMPLE_RATE = 16000 # Hz"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cache-aware streaming Fastconformer\n",
    "In this tutorial, we will do streaming transcription using NeMo models that were specially trained for use in streaming applications. These models are described in the paper released by the NeMo team: [*Noroozi et al.* \"Stateful FastConformer with Cache-based Inference for Streaming Automatic Speech Recognition](https://arxiv.org/abs/2312.17279)\" (accepted to ICASSP 2024).\n",
    "\n",
    "These models have the following features:\n",
    "* They were trained such that at each timestep, the decoder (either RNNT or CTC) would receive a limited amount of context on the left and (most importantly) the right side. Keeping the right side context small means that in a real time streaming scenario, we do not need to keep recording for very long before we are able to compute the output token at that timestep - thus we are able to get transcriptions with a low latency.\n",
    "* The model implementation has **caching** enabled, meaning we do not need to recalculate activations that were obtained in previous timesteps, thus reducing latency further.\n",
    "\n",
    "\n",
    "## Model checkpoints\n",
    "The following checkpoints of these models are currently available, and are compatible with this notebook. The meaning of \"lookahead\" and \"chunk size\" is described in the following section.\n",
    "\n",
    "1) [`stt_en_fastconformer_hybrid_large_streaming_80ms`](https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_80ms) - 80ms lookahead / 160ms chunk size\n",
    "\n",
    "2) [`stt_en_fastconformer_hybrid_large_streaming_480ms`](https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_480ms) - 480ms lookahead / 540ms chunk size\n",
    "\n",
    "3) [`stt_en_fastconformer_hybrid_large_streaming_1040ms`](https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_1040ms) - 1040ms lookahead / 1120ms chunk size\n",
    "\n",
    "4) [`stt_en_fastconformer_hybrid_large_streaming_multi`](https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_multi) - 0ms, 80ms, 480ms, 1040ms lookahead / 80ms, 160ms, 540ms, 1120ms chunk size\n",
    "\n",
    "## Model inference explanation\n",
    "We run inference by continuously recording our audio in chunks, and feeding the chunks into the chosen ASR model. In this notebook we use `pyaudio` to open an audio input stream, and pass the audio to a `stream_callback` function every \"chunk-sized\" number of seconds. In the `stream_callback` function, we pass the audio signal to a `transcribe` function (which we will specify in this notebook), and print the resulting transcription.\n",
    "\n",
    "As mentioned, the \"chunk size\" is the duration of audio that we feed into the ASR model at a time (and we keep doing this continuously, to allow for real-time, streaming speech recognition).\n",
    "\n",
    "\"Lookahead\" size is the \"chunk size\" minus the duration of a single output timestep from the decoder. For FastConformer models, the duration of an output timestep is always 80ms, hence in this notebook always `lookahead size = chunk size - 80 ms`."
   ]
  },
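  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relationship between lookahead and chunk size described above can be sketched in a few lines of Python. This is only an illustrative calculation, not part of the NeMo API: `chunk_size_ms` and `samples_per_chunk` are hypothetical helper names, and the sketch assumes the 80 ms output timestep and the 16 kHz sample rate used in this notebook.\n",
    "\n",
    "```python\n",
    "ENCODER_STEP_MS = 80    # assumed duration of one output timestep, in ms\n",
    "SAMPLE_RATE_HZ = 16000  # assumed recording sample rate, in Hz\n",
    "\n",
    "\n",
    "def chunk_size_ms(lookahead_ms):\n",
    "    # chunk size = lookahead + one output timestep\n",
    "    return lookahead_ms + ENCODER_STEP_MS\n",
    "\n",
    "\n",
    "def samples_per_chunk(lookahead_ms):\n",
    "    # number of audio samples in one chunk at the assumed sample rate\n",
    "    return SAMPLE_RATE_HZ * chunk_size_ms(lookahead_ms) // 1000\n",
    "\n",
    "\n",
    "for lookahead in (0, 80, 480, 1040):\n",
    "    print(lookahead, chunk_size_ms(lookahead), samples_per_chunk(lookahead))\n",
    "```\n",
    "\n",
    "For example, an 80 ms lookahead gives a 160 ms chunk, which is 2560 samples at 16 kHz."
   ]
  },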
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model selection\n",
    "In the next cell, you can select which pretrained `model_name` and `lookahead_size` you would like to try.\n",
    "\n",
    "Additionally, note that all of the available models are [Hybrid RNNT-CTC models](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/models.html#hybrid-transducer-ctc). Inference is by default done using the RNNT decoder (which tends to produce a higher transcription accuracy), but you may choose to use the CTC decoder instead. For this, we also provide a `decoder_type` variable in the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You may wish to try different values of model_name and lookahead_size\n",
    "\n",
    "# Choose a the name of a model to use.\n",
    "# Currently available options:\n",
    "# 1) \"stt_en_fastconformer_hybrid_large_streaming_multi\"\n",
    "# 2) \"stt_en_fastconformer_hybrid_large_streaming_80ms\"\n",
    "# 3) \"stt_en_fastconformer_hybrid_large_streaming_480ms\"\n",
    "# 4) \"stt_en_fastconformer_hybrid_large_streaming_1040ms\"\n",
    "\n",
    "model_name = \"stt_en_fastconformer_hybrid_large_streaming_multi\"\n",
    "\n",
    "# Specify the lookahead_size.\n",
    "# If model_name == \"stt_en_fastconformer_hybrid_large_streaming_multi\" then\n",
    "# lookahead_size can be 0, 80, 480 or 1040 (ms)\n",
    "# Else, lookahead_size should be whatever is written in the model_name:\n",
    "# \"stt_en_fastconformer_hybrid_large_streaming_<lookahead_size>ms\"\n",
    "\n",
    "lookahead_size = 80 # in milliseconds\n",
    "\n",
    "# Specify the decoder to use.\n",
    "# Can be \"rnnt\" or \"ctc\"\n",
    "decoder_type = \"rnnt\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model set-up\n",
    "Next we:\n",
    "* set up the `asr_model` according to the chosen `model_name` and `lookahead_size`\n",
    "* make sure we use the specified `decoder_type`\n",
    "* make sure the model's decoding strategy has suitable parameters\n",
    "* instantiate a `CacheAwareStreamingAudioBuffer`\n",
    "* get some parameters to use as the initial cache state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# setting up model and validating the choice of model_name and lookahead size\n",
    "asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)\n",
    "\n",
    "\n",
    "# specify ENCODER_STEP_LENGTH (which is 80 ms for FastConformer models)\n",
    "ENCODER_STEP_LENGTH = 80 # ms\n",
    "\n",
    "# update att_context_size if using multi-lookahead model\n",
    "# (for single-lookahead models, the default context size will be used and the\n",
    "# `lookahead_size` variable will be ignored)\n",
    "if model_name == \"stt_en_fastconformer_hybrid_large_streaming_multi\":\n",
    "    # check that lookahead_size is one of the valid ones\n",
    "    if lookahead_size not in [0, 80, 480, 1040]:\n",
    "        raise ValueError(\n",
    "            f\"specified lookahead_size {lookahead_size} is not one of the \"\n",
    "            \"allowed lookaheads (can select 0, 80, 480 or 1040 ms)\"\n",
    "        )\n",
    "\n",
    "    # update att_context_size\n",
    "    left_context_size = asr_model.encoder.att_context_size[0]\n",
    "    asr_model.encoder.set_default_att_context_size([left_context_size, int(lookahead_size / ENCODER_STEP_LENGTH)])\n",
    "\n",
    "\n",
    "# make sure we use the specified decoder_type\n",
    "asr_model.change_decoding_strategy(decoder_type=decoder_type)\n",
    "\n",
    "# make sure the model's decoding strategy is optimal\n",
    "decoding_cfg = asr_model.cfg.decoding\n",
    "with open_dict(decoding_cfg):\n",
    "    # save time by doing greedy decoding and not trying to record the alignments\n",
    "    decoding_cfg.strategy = \"greedy\"\n",
    "    decoding_cfg.preserve_alignments = False\n",
    "    if hasattr(asr_model, 'joint'):  # if an RNNT model\n",
    "        # restrict max_symbols to make sure not stuck in infinite loop\n",
    "        decoding_cfg.greedy.max_symbols = 10\n",
    "        # sensible default parameter, but not necessary since batch size is 1\n",
    "        decoding_cfg.fused_batch_size = -1\n",
    "    asr_model.change_decoding_strategy(decoding_cfg)\n",
    "\n",
    "\n",
    "# set model to eval mode\n",
    "asr_model.eval()\n",
    "\n",
    "\n",
    "# get parameters to use as the initial cache state\n",
    "cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(\n",
    "    batch_size=1\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transcribing a single chunk\n",
    "In the following code block we specify the `transcribe_chunk` function that transcribes a single chunk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# init params we will use for streaming\n",
    "previous_hypotheses = None\n",
    "pred_out_stream = None\n",
    "step_num = 0\n",
    "pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]\n",
    "# cache-aware models require some small section of the previous processed_signal to\n",
    "# be fed in at each timestep - we initialize this to a tensor filled with zeros\n",
    "# so that we will do zero-padding for the very first chunk(s)\n",
    "num_channels = asr_model.cfg.preprocessor.features\n",
    "cache_pre_encode = torch.zeros((1, num_channels, pre_encode_cache_size), device=asr_model.device)\n",
    "\n",
    "\n",
    "# helper function for extracting transcriptions\n",
    "def extract_transcriptions(hyps):\n",
    "    \"\"\"\n",
    "        The transcribed_texts returned by CTC and RNNT models are different.\n",
    "        This method would extract and return the text section of the hypothesis.\n",
    "    \"\"\"\n",
    "    if isinstance(hyps[0], Hypothesis):\n",
    "        transcriptions = []\n",
    "        for hyp in hyps:\n",
    "            transcriptions.append(hyp.text)\n",
    "    else:\n",
    "        transcriptions = hyps\n",
    "    return transcriptions\n",
    "\n",
    "# define functions to init audio preprocessor and to\n",
    "# preprocess the audio (ie obtain the mel-spectrogram)\n",
    "def init_preprocessor(asr_model):\n",
    "    cfg = copy.deepcopy(asr_model._cfg)\n",
    "    OmegaConf.set_struct(cfg.preprocessor, False)\n",
    "\n",
    "    # some changes for streaming scenario\n",
    "    cfg.preprocessor.dither = 0.0\n",
    "    cfg.preprocessor.pad_to = 0\n",
    "    cfg.preprocessor.normalize = \"None\"\n",
    "    \n",
    "    preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)\n",
    "    preprocessor.to(asr_model.device)\n",
    "    \n",
    "    return preprocessor\n",
    "\n",
    "preprocessor = init_preprocessor(asr_model)\n",
    "\n",
    "def preprocess_audio(audio, asr_model):\n",
    "    device = asr_model.device\n",
    "\n",
    "    # doing audio preprocessing\n",
    "    audio_signal = torch.from_numpy(audio).unsqueeze_(0).to(device)\n",
    "    audio_signal_len = torch.Tensor([audio.shape[0]]).to(device)\n",
    "    processed_signal, processed_signal_length = preprocessor(\n",
    "        input_signal=audio_signal, length=audio_signal_len\n",
    "    )\n",
    "    return processed_signal, processed_signal_length\n",
    "\n",
    "\n",
    "def transcribe_chunk(new_chunk):\n",
    "    \n",
    "    global cache_last_channel, cache_last_time, cache_last_channel_len\n",
    "    global previous_hypotheses, pred_out_stream, step_num\n",
    "    global cache_pre_encode\n",
    "    \n",
    "    # new_chunk is provided as np.int16, so we convert it to np.float32\n",
    "    # as that is what our ASR models expect\n",
    "    audio_data = new_chunk.astype(np.float32)\n",
    "    audio_data = audio_data / 32768.0\n",
    "\n",
    "    # get mel-spectrogram signal & length\n",
    "    processed_signal, processed_signal_length = preprocess_audio(audio_data, asr_model)\n",
    "     \n",
    "    # prepend with cache_pre_encode\n",
    "    processed_signal = torch.cat([cache_pre_encode, processed_signal], dim=-1)\n",
    "    processed_signal_length += cache_pre_encode.shape[1]\n",
    "    \n",
    "    # save cache for next time\n",
    "    cache_pre_encode = processed_signal[:, :, -pre_encode_cache_size:]\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        (\n",
    "            pred_out_stream,\n",
    "            transcribed_texts,\n",
    "            cache_last_channel,\n",
    "            cache_last_time,\n",
    "            cache_last_channel_len,\n",
    "            previous_hypotheses,\n",
    "        ) = asr_model.conformer_stream_step(\n",
    "            processed_signal=processed_signal,\n",
    "            processed_signal_length=processed_signal_length,\n",
    "            cache_last_channel=cache_last_channel,\n",
    "            cache_last_time=cache_last_time,\n",
    "            cache_last_channel_len=cache_last_channel_len,\n",
    "            keep_all_outputs=False,\n",
    "            previous_hypotheses=previous_hypotheses,\n",
    "            previous_pred_out=pred_out_stream,\n",
    "            drop_extra_pre_encoded=None,\n",
    "            return_transcription=True,\n",
    "        )\n",
    "    \n",
    "    final_streaming_tran = extract_transcriptions(transcribed_texts)\n",
    "    step_num += 1\n",
    "    \n",
    "    return final_streaming_tran[0]"
   ]
  },
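  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pre-encode cache handling inside `transcribe_chunk` can be illustrated in isolation. The sketch below uses small NumPy arrays instead of real mel-spectrograms; `roll_cache`, the feature count of 4 and the cache size of 3 are hypothetical choices made for illustration, not values taken from NeMo.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def roll_cache(cache, features):\n",
    "    # prepend the cached frames along the time axis (last dimension)\n",
    "    extended = np.concatenate([cache, features], axis=-1)\n",
    "    # keep the most recent frames as the cache for the next chunk\n",
    "    new_cache = extended[:, :, -cache.shape[-1]:]\n",
    "    return extended, new_cache\n",
    "\n",
    "\n",
    "cache = np.zeros((1, 4, 3), dtype=np.float32)  # zero-padding for the first chunk\n",
    "chunk = np.ones((1, 4, 5), dtype=np.float32)   # stand-in for one chunk of features\n",
    "\n",
    "extended, cache = roll_cache(cache, chunk)\n",
    "print(extended.shape)  # (1, 4, 8)\n",
    "print(cache.shape)     # (1, 4, 3)\n",
    "```\n",
    "\n",
    "The time dimension grows by the cache length on every step, while the cache itself always keeps a fixed number of the most recent frames."
   ]
  },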
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple streaming with microphone\n",
    "We use `pyaudio` to record audio from an input audio device on your local machine. We use a `stream_callback` which will be called every `frames_per_buffer` number of frames, and conduct the transcription, which will be printed in the output of the cell below."
   ]
  },
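  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The callback receives the recorded audio as raw bytes holding 16-bit integer samples, which are scaled to floating point before being passed to the model. A minimal sketch of that conversion is shown here without any microphone involved; `bytes_to_float32` is a hypothetical helper name, and the three sample values are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def bytes_to_float32(in_data):\n",
    "    # interpret the raw buffer as 16-bit integer samples\n",
    "    signal = np.frombuffer(in_data, dtype=np.int16)\n",
    "    # scale to the range [-1.0, 1.0)\n",
    "    return signal.astype(np.float32) / 32768.0\n",
    "\n",
    "\n",
    "raw = np.array([0, 16384, -32768], dtype=np.int16).tobytes()\n",
    "print(bytes_to_float32(raw))  # values 0.0, 0.5 and -1.0\n",
    "```"
   ]
  },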
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calculate chunk_size in milliseconds\n",
    "chunk_size = lookahead_size + ENCODER_STEP_LENGTH\n",
    "\n",
    "p = pa.PyAudio()\n",
    "print('Available audio input devices:')\n",
    "input_devices = []\n",
    "for i in range(p.get_device_count()):\n",
    "    dev = p.get_device_info_by_index(i)\n",
    "    if dev.get('maxInputChannels'):\n",
    "        input_devices.append(i)\n",
    "        print(i, dev.get('name'))\n",
    "\n",
    "if len(input_devices):\n",
    "    dev_idx = -2\n",
    "    while dev_idx not in input_devices:\n",
    "        print('Please type input device ID:')\n",
    "        dev_idx = int(input())\n",
    "\n",
    "    def callback(in_data, frame_count, time_info, status):\n",
    "        signal = np.frombuffer(in_data, dtype=np.int16)\n",
    "        text = transcribe_chunk(signal)\n",
    "        print(text, end='\\r')\n",
    "        return (in_data, pa.paContinue)\n",
    "\n",
    "    stream = p.open(format=pa.paInt16,\n",
    "                    channels=1,\n",
    "                    rate=SAMPLE_RATE,\n",
    "                    input=True,\n",
    "                    input_device_index=dev_idx,\n",
    "                    stream_callback=callback,\n",
    "                    frames_per_buffer=int(SAMPLE_RATE * chunk_size / 1000) - 1\n",
    "                   )\n",
    "\n",
    "    print('Listening...')\n",
    "\n",
    "    stream.start_stream()\n",
    "    \n",
    "    # Interrupt kernel and then speak for a few more words to exit the pyaudio loop !\n",
    "    try:\n",
    "        while stream.is_active():\n",
    "            time.sleep(0.1)\n",
    "    finally:        \n",
    "        stream.stop_stream()\n",
    "        stream.close()\n",
    "        p.terminate()\n",
    "\n",
    "        print()\n",
    "        print(\"PyAudio stopped\")\n",
    "    \n",
    "else:\n",
    "    print('ERROR: No audio input device found.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
